% Probabilistic Latent Semantic Analysis (PLSA) fitted by EM on a toy corpus.
%
% Model (symmetric parameterisation):
%   P(d_i, w_j) = sum_k P(z_k) * P(d_i|z_k) * P(w_j|z_k)
%
% Workspace variables left after the script runs:
%   w_z  - numWords x numTopics; column k holds P(w_j|z_k) and sums to 1
%   d_z  - numDocs x numTopics;  column k holds P(d_i|z_k) and sums to 1
%   z    - numTopics x numTopics diagonal matrix with z(k,k) = P(z_k)
%   z_dw - numTopics x numDocs x numWords posterior P(z_k|d_i,w_j)
%          from the last E-step
%   iter - number of EM iterations performed

clear

% Term-document count matrix: rows are words, columns are documents.
X = [
1 0 0 0 0 0;
1 0 1 0 0 0;
1 0 0 0 0 0;
0 1 1 0 0 0;
0 1 1 0 0 0;
0 0 0 1 0 1;
0 0 0 1 0 1;
0 0 0 0 1 1;
0 0 0 0 1 0
];

numDocs = size(X,2);
numWords = size(X,1);
numTopics = 2;
tol = 0.01;        % stop once the L1 change in w_z drops below this
maxIter = 1000;    % hard cap so a run that never reaches tol still terminates

% P(w_j|z_k): random start, each column normalised to sum to 1.
w_z = rand(numWords, numTopics);
w_z = bsxfun(@rdivide, w_z, sum(w_z, 1));

% P(d_i|z_k): random start, each column normalised to sum to 1.
% (The M-step below normalises over documents, so this is P(d|z), not P(z|d).)
d_z = rand(numDocs, numTopics);
d_z = bsxfun(@rdivide, d_z, sum(d_z, 1));

% P(z_k): uniform prior, kept on the diagonal of z.
z = eye(numTopics) / numTopics;

% P(z_k|d_i,w_j), filled by the E-step.
z_dw = zeros(numTopics, numDocs, numWords);

totalCount = sum(X(:));   % N = sum_{i,j} n(d_i,w_j); normaliser for P(z_k)
convergence = Inf;
iter = 0;
while convergence > tol && iter < maxIter
    iter = iter + 1;

    % E-step: P(z_k|d_i,w_j) is proportional to P(z_k) P(d_i|z_k) P(w_j|z_k),
    % normalised over k. The outer product covers all (i,j) pairs at once.
    for k = 1:numTopics
        z_dw(k,:,:) = reshape(z(k,k) * d_z(:,k) * w_z(:,k)', [1 numDocs numWords]);
    end
    denom = sum(z_dw, 1);
    % A zero denominator means (d_i,w_j) has zero probability under every
    % topic (e.g. an all-zero row or column of X); keep its posterior at 0
    % rather than letting 0/0 = NaN poison every later update.
    denom(denom == 0) = 1;
    z_dw = bsxfun(@rdivide, z_dw, denom);

    % M-step: expected counts n(d_i,w_j) * P(z_k|d_i,w_j) for each topic.
    % Every parameter is re-estimated from this E-step's z_dw, and the
    % shared normaliser (topicMass) is computed once per topic.
    w_z_old = w_z;
    for k = 1:numTopics
        counts = X' .* reshape(z_dw(k,:,:), numDocs, numWords);  % numDocs x numWords
        topicMass = sum(counts(:));
        w_z(:,k) = sum(counts, 1)' / topicMass;   % P(w_j|z_k)
        d_z(:,k) = sum(counts, 2) / topicMass;    % P(d_i|z_k)
        z(k,k) = topicMass / totalCount;          % P(z_k)
    end

    convergence = sum(abs(w_z(:) - w_z_old(:)));
    fprintf('iteration %d: convergence = %g\n', iter, convergence);
end

% Written as ~(<=) so a NaN convergence value is also reported.
if ~(convergence <= tol)
    warning('PLSA EM stopped after %d iterations without reaching tol = %g (last change %g).', iter, tol, convergence);
end

w_z
d_z
z